"""
策略查看器的后端服务。
"""
import ast
import base64
import copy
import json
import logging
import os
import pickle
import time
from datetime import datetime
from glob import iglob
from logging.handlers import TimedRotatingFileHandler

import pandas as pd
import yaml
from dotenv import load_dotenv
from flask import Flask, abort, request, send_from_directory, jsonify, send_from_directory
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import text

from function import get_rb_trade_data

app = Flask(__name__)
data_path_all = {}  # cache of every data-file path, keyed by instrument code, for quick lookup


def init_logger(app,
                file_name='./log/server.log',
                console_level="INFO",
                format="%(asctime)s-%(name)s-%(levelname)s-%(filename)s-%(lineno)d-%(message)s",
                file_level="DEBUG"):
    """
    Initialise the Flask app logger with a rotating file handler and a console handler.

    - app: the Flask application object.
    - file_name: log file path; its parent directory is created if missing.
    - console_level: console handler level (default INFO).
    - format: record format: time-name-level-file-line-message.
    - file_level: file handler level (default DEBUG).
    """
    app.logger.setLevel(logging.DEBUG)  # logger passes everything; the handlers filter
    fm = logging.Formatter(format)

    # File handler: rotate at midnight; backupCount=0 means old files are never deleted.
    log_dir = os.path.dirname(file_name)
    if log_dir:
        # TimedRotatingFileHandler raises FileNotFoundError if ./log is absent.
        os.makedirs(log_dir, exist_ok=True)
    fh = TimedRotatingFileHandler(file_name, when='midnight', backupCount=0)
    fh.setFormatter(fm)
    fh.setLevel(file_level)
    app.logger.addHandler(fh)

    # Console handler
    ch = logging.StreamHandler()
    ch.setFormatter(fm)
    ch.setLevel(console_level)
    app.logger.addHandler(ch)


def load_config(app, filename='config.yaml'):
    """Load the YAML config file into app.config and pull the DB password from the environment."""
    with open(filename, 'r') as fp:
        app.config.update(yaml.safe_load(fp))

    # Secrets live in .env, not in the YAML file.
    load_dotenv()
    app.config['MYSQL_PASSWORD'] = os.getenv('MYSQL_PASSWORD')


# Load configuration (config.yaml + .env secrets)
load_config(app)
# Initialise logging (file + console handlers)
init_logger(app)
# Build the database connection URI from config values
user = app.config['db_user']
password = app.config['MYSQL_PASSWORD']
host = app.config['db_host']
database = app.config['db_database']
app.config['SQLALCHEMY_DATABASE_URI'] = f"mysql+pymysql://{user}:{password}@{host}/{database}"
db = SQLAlchemy(app)


@app.route("/")
def index():
    """Trivial health-check endpoint."""
    return "Hello!"


@app.route("/qs_file/", methods=['GET'])
def qs_file():
    """Serve the quantstats HTML report for the strategy named in the `stg` query arg."""
    selected_stg = request.args.get('stg', '')
    if not selected_stg:
        app.logger.info(f'没有传递策略参数')
        abort(400)  # missing strategy parameter (was `return abort(...)`; abort raises)
    # Hoist the two config lookups so the path is built exactly once.
    stg_dir = app.config[f'stg_{selected_stg}_path']
    qs_name = app.config['quantstats_file_name']
    if not os.path.exists(os.path.join(stg_dir, qs_name)):
        app.logger.debug(stg_dir)
        app.logger.debug(qs_name)
        app.logger.warning(f'策略{selected_stg}的qs文件不存在！')
        abort(400)
    return send_from_directory(stg_dir, qs_name)


def read_from_marker(filename, marker):
    """Return the lines of *filename* from the last line containing *marker* through EOF.

    The marker line itself is included. If the marker never appears, every
    line of the file is returned.
    """
    with open(filename, 'r') as fh:
        lines = fh.readlines()

    # Walk backwards so we find the marker occurrence closest to the end,
    # collecting lines as we go, then restore chronological order.
    collected = []
    for line in reversed(lines):
        collected.append(line)
        if marker in line:
            break
    return collected[::-1]


@app.route("/backtest_result", methods=['GET'])
def backtest_result():
    """Return the tail of the backtest log, from the last '运行结束' marker onward."""
    stg = request.args.get('stg', '')
    if not stg:
        app.logger.error('查看账户没有策略名')
        abort(400)
    # NOTE(review): the result-file config key is hard-coded to rb_cci —
    # confirm it is intended for every value of `stg`.
    log_path = app.config[f"stg_{stg}_path"] + '/' + app.config['stg_rb_cci_backtest_result']
    if not os.path.exists(log_path):
        app.logger.warning(f'文件不存在！')
        abort(400)
    return read_from_marker(log_path, '运行结束')  # read from the last "run finished" record


@app.route("/code_list", methods=['GET'])
def get_code_list():
    """Return all known instrument codes, lazily scanning the data folders on first call."""
    if not data_path_all:
        stock_dir = app.config['stock_path']
        if stock_dir:
            for path in iglob(stock_dir + '*.csv'):
                code = os.path.basename(path).replace('.csv', '')
                data_path_all[code] = {"path": path, "type": 'stock'}

        crypto_dir = app.config['crypto_path']
        if crypto_dir:
            for path in iglob(crypto_dir + '*.csv'):
                code = os.path.basename(path).replace('.csv', '')
                data_path_all[code] = {"path": path, "type": 'crypto'}

        futures_dir = app.config['futures_path']
        if futures_dir:
            # Futures file names look like "<prefix>.<code>.csv"; keep only the code part.
            for path in iglob(futures_dir + '*.csv'):
                code = os.path.basename(path).replace('.csv', '').split('.')[1]
                data_path_all[code] = {"path": path, "type": 'future'}

    return jsonify(list(data_path_all.keys()))


@app.route("/kline_option", methods=['GET'])
def get_kline_option():
    """
    Build the ECharts option for one instrument's K-line chart.

    Query args:
    - symbol: instrument code (must be a key of data_path_all).
    - stg: strategy name, used to resolve trade-record, plot-config and
      factor-feed file paths from app.config.

    A JSON template supplies the base option; OHLC/volume data, factor line
    series (per the plot config) and buy/sell mark points from the backtest
    trade record are filled in on top of it.
    """
    # Read the GET parameters
    instrument = request.args.get('symbol', '')  # instrument code
    selected_stg = request.args.get('stg', '')  # strategy name
    app.logger.debug(f'标的是：{instrument}, 策略是{selected_stg}')

    # Backtest trade record (CSV with at least date/price/action/volume columns).
    trade_file = os.path.join(app.config[f'stg_{selected_stg}_path'], app.config['trade_file_name'])
    trade_record = pd.read_csv(trade_file)
    app.logger.debug(f'交易记录：{trade_file} 读取成功。')

    # Plot config: maps factor name -> {"where": "main" | "auxiliary"}.
    plot_config_file = app.config[f'stg_{selected_stg}_plot_config_path']
    with open(plot_config_file, 'r') as f:
        plot_config = json.load(f)
    app.logger.debug(f'画图配置：{plot_config_file} 读取成功。')

    # Factor feed: pickled object whose get_all() presumably returns a mapping
    # of instrument -> factor DataFrame — TODO confirm against the producer.
    factor_file = os.path.join(app.config[f'stg_{selected_stg}_path'], app.config['factor_feed_file'])
    with open(factor_file, 'rb') as f:
        _ = pickle.load(f)
        _ = _.get_all()
        factor_feed = _
    app.logger.debug(f'因子数据：{factor_file} 读取成功。')

    # Base ECharts option template.
    p = os.path.dirname(os.path.abspath(__file__))
    with open(p + '/static/template/backtest_kline_option.json', 'r') as f:
        option = json.load(f)
    if instrument != '':
        ins_info = data_path_all[instrument]
        if os.path.exists(ins_info['path']):
            app.logger.debug(f'标的路径是：{ins_info["path"]}')
            if ins_info['type'] == 'future':  # each instrument type has its own data-file layout
                data = pd.read_csv(ins_info['path'])
                option["series"][1]["data"] = data["volume"].tolist()
                option["xAxis"][0]["data"] = data['date'].tolist()
                option["xAxis"][1]["data"] = data['date'].tolist()
                option["series"][0]["data"] = data.loc[:, ['open', 'close', 'low', 'high']].to_dict('split')['data']
            elif ins_info['type'] == 'stock':
                # Stock CSVs are GBK-encoded, have one header row to skip, and
                # use Chinese column names (volume / trade date / OHLC).
                data = pd.read_csv(ins_info['path'], encoding='gbk', skiprows=1)
                option["series"][1]["data"] = data["成交量"].tolist()
                option["xAxis"][0]["data"] = data['交易日期'].tolist()
                option["xAxis"][1]["data"] = data['交易日期'].tolist()
                option["series"][0]["data"] = \
                    data.loc[:, ['开盘价', '收盘价', '最低价', '最高价']].to_dict('split')['data']
            elif ins_info['type'] == 'crypto':
                # Crypto data not wired up yet; chart stays empty.
                data = pd.DataFrame()
            else:
                app.logger.error(f'错误的标的类型：{ins_info["type"]}')
                return jsonify({"error": "Not Found", "message": f'错误的标的类型：{ins_info["type"]}'}), 404
            # Overlay factor series: "main" plots on the K-line axes,
            # "auxiliary" on the third (separate) axis pair.
            if factor_feed and instrument in factor_feed and plot_config:
                app.logger.debug(f'使用的画图配置是：{plot_config}')
                series_auxiliary = {
                    "type": "line",
                    "name": "factor",
                    "xAxisIndex": 2,
                    "yAxisIndex": 2,
                    # "color": "gray",
                    "data": [],
                    "lineStyle": {
                        "width": 1
                    }
                }
                series_main = {
                    "type": "line",
                    "name": "factor",
                    "data": [],
                    "lineStyle": {
                        "width": 1
                    }
                }
                for k, v in plot_config.items():
                    if v['where'] == "main":
                        tp = copy.deepcopy(series_main)
                        tp["data"] = factor_feed[instrument][k].bfill().tolist()
                        tp['name'] = k
                        option["series"].append(tp)
                    elif v['where'] == "auxiliary":
                        tp = copy.deepcopy(series_auxiliary)
                        tp["data"] = factor_feed[instrument][k].bfill().tolist()
                        tp['name'] = k
                        option["series"].append(tp)
                        option["xAxis"][2]["data"] = data['date'].tolist()
                    else:
                        app.logger.error(f"显示参数不正确！{v['where']}")
                        return jsonify({"error": "internal error", "message": "显示参数不正确"}), 500
                    option['legend']['data'].append(k)

        else:
            app.logger.error(f"文件{ins_info['path']}不存在，请重新输入！")
            return jsonify({"error": "Not Found", "message": "Data not found."}), 404
    if not trade_record.empty:
        app.logger.debug('append 交易记录~')
        # Mark-point template: an arrow at (date, price); rotation and colour
        # are overwritten per trade direction below.
        point = {
            "coord": ["2023-05-25", 10],
            "symbol": "arrow",
            "symbolSize": 12,
            "symbolRotate": 180,
            "label": {"formatter": "123"},
            "itemStyle": {"color": "grey"}
        }
        for i, row in trade_record.iterrows():
            _ = copy.deepcopy(point)
            _['coord'] = [row['date'], row['price']]
            if row['action'] == "BUY":
                # Open long: upward red arrow labelled with volume.
                _['label']['formatter'] = f"{row['volume']}"
                _['symbolRotate'] = 0
                _['itemStyle']['color'] = 'red'

            elif row['action'] == "SELL_SHORT":
                # Open short: downward green arrow.
                _['label']['formatter'] = f"{row['volume']}"
                _['itemStyle']['color'] = 'green'

            elif row['action'] == "SELL":
                # Close long: keeps the template's downward grey arrow.
                _['label']['formatter'] = f"{row['volume']}"
            elif row['action'] == "BUY_TO_COVER":
                # Close short: upward grey arrow.
                _['symbolRotate'] = 0
                _['label']['formatter'] = f"{row['volume']}"
            option["series"][0]["markPoint"]["data"].append(_)

    return jsonify(option)


@app.route("/backtest_echart_option")
def get_backtest_echart_option():
    """Build the ECharts option for the backtest equity / drawdown / benchmark chart."""
    stg = request.args.get('stg')
    log_or_plain = request.args.get('log_or_plain')
    frequency = request.args.get('frequency')

    template_dir = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(template_dir, 'static/template/backtest_echarts_option.json')) as fp:
        option = json.load(fp)

    data_file = os.path.join(app.config[f'stg_{stg}_path'], app.config['image_data_name'])
    with open(data_file, 'rb') as fp:
        frame = pickle.load(fp)

    frame = pd.DataFrame(frame)
    frame['datetime'] = pd.to_datetime(frame['datetime'])
    frame.set_index('datetime', inplace=True)
    if frequency == 'DAY':
        # Downsample intraday records to one (mean) row per calendar day.
        frame = frame.resample('D').mean().dropna()
    frame.reset_index(inplace=True)
    frame['datetime'] = frame['datetime'].apply(str)
    columns = frame.to_dict(orient='list')

    for axis_idx in range(3):
        option['xAxis'][axis_idx]["data"] = columns['datetime']
    option['series'][1]["data"] = columns['draw_down']
    option['series'][2]["data"] = columns['equity_vs_market_value']
    option['series'][3]["data"] = columns['benchmark']
    if log_or_plain == 'Log':
        option['series'][0]["data"] = columns['equity_log']
    else:
        # Plain equity needs a linear (not log) y axis.
        option['series'][0]["data"] = columns['equity']
        option['yAxis'][0]['type'] = 'value'
    return option


@app.route("/get_account", methods=['GET'])
def get_account():
    """Return the two most recent account rows for the given strategy as JSON."""
    stg = request.args.get('stg', '')
    stg = stg.upper()
    if not stg:
        app.logger.error('查看账户没有策略名')
        abort(400)
    # `stg` is interpolated into SQL as a schema name; restrict it to a plain
    # identifier to prevent SQL injection from the query string.
    if not stg.replace('_', '').isalnum():
        app.logger.error(f'非法的策略名：{stg}')
        abort(400)
    sql = f"""SELECT * FROM {stg}.account ORDER BY id DESC LIMIT 2"""  # latest two rows
    result = db.session.execute(text(sql))
    res = []
    for row in result:
        tp = row._asdict()
        # %d = zero-padded day of month; the original %D expands to mm/dd/yy.
        tp['datetime'] = tp['datetime'].strftime("%Y-%m-%d %H:%M")
        res.append(tp)
    return jsonify(res)


@app.route("/get_position", methods=['GET'])
def get_position():
    """Return every row of the most recent position snapshot as JSON."""
    stg = request.args.get('stg', '')
    stg = stg.upper()
    if not stg:
        app.logger.error('查看账户没有策略名')
        abort(400)
    # `stg` is interpolated into SQL as a schema name; restrict it to a plain
    # identifier to prevent SQL injection from the query string.
    if not stg.replace('_', '').isalnum():
        app.logger.error(f'非法的策略名：{stg}')
        abort(400)
    # All rows sharing the latest datetime (one row per held instrument).
    sql = f"""
    select * from {stg}.position where datetime = (select datetime from {stg}.position order by id desc limit 1)
    """
    result = db.session.execute(text(sql))
    res = []
    for row in result:
        tp = row._asdict()
        # %d = zero-padded day of month; the original %D expands to mm/dd/yy.
        tp['datetime'] = tp['datetime'].strftime("%Y-%m-%d %H:%M")
        res.append(tp)

    return jsonify(res)


@app.route("/get_day_compare")
def get_day_compare():
    """Compare today's live trading against the backtest for the given strategy."""
    stg = request.args.get('stg', '')
    STG = stg.upper()
    if not stg:
        app.logger.error('没有策略名!')
        abort(400)
    # The schema name is interpolated into SQL below — reject anything that is
    # not a plain identifier to prevent SQL injection from the query string.
    if not STG.replace('_', '').isalnum():
        app.logger.error(f'非法的策略名：{stg}')
        abort(400)

    # Latest positions: all rows sharing the most recent datetime.
    sql_pos = f"""
        select * from {STG}.position where datetime = (select datetime from {STG}.position order by id desc limit 1)
        """
    res = db.session.execute(text(sql_pos))
    pos = []
    for row in res:
        tp = row._asdict()
        # %d = zero-padded day of month; the original %D expands to mm/dd/yy.
        tp['datetime'] = tp['datetime'].strftime("%Y-%m-%d %H:%M")
        pos.append(tp)
    pos = pd.DataFrame(pos)

    # The two most recent 15:00 account snapshots, used for the daily return.
    sql_account = f"""SELECT * FROM {STG}.account WHERE
    HOUR(datetime) = 15 ORDER BY id DESC LIMIT 2"""
    result = db.session.execute(text(sql_account))
    account = pd.DataFrame(result)
    if account.shape[0] != 2:
        app.logger.debug(f'获取数据有误，不是两条！')
        abort(500)
    # NOTE(review): rows arrive newest-first (ORDER BY id DESC), so this is
    # older/newer — confirm the intended ratio direction.
    day_return = (account['balance'].iloc[-1] / account['balance'].iloc[0] - 1) * 100

    # Today's trade-by-trade comparison (live vs backtest).
    out_put, cmp, od, merged_df, trade_record, real_trade = compare(app, stg)
    ret = {
        "out_put": pickle_encode(out_put),
        "cmp": pickle_encode(cmp),
        "od": od,
        "merged_df": pickle_encode(merged_df),
        "trade_record": pickle_encode(trade_record),
        "real_trade": pickle_encode(real_trade),
        "day_return": day_return,
        "pos": pickle_encode(pos),
        "account": pickle_encode(account)
    }
    return jsonify(ret)


@app.route("/get_his_compare")
def get_his_compare():
    """
    Compare historical live P&L against the backtest equity curve.

    Query args:
    - stg: strategy name.
    - days: look-back window for the per-trade comparison (default 20).
    """
    stg = request.args.get('stg', '')
    days = int(request.args.get('days', 20))
    if not stg:
        app.logger.error('没有策略名!')
        abort(400)
    # The schema name is interpolated into SQL below — reject anything that is
    # not a plain identifier to prevent SQL injection from the query string.
    if not stg.replace('_', '').isalnum():
        app.logger.error(f'非法的策略名：{stg}')
        abort(400)

    # ECharts template
    p = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(p, 'static/template/pnl_conpare_to_real_option.json')) as f:
        option = json.load(f)

    # Backtest equity data, resampled to daily means.
    path = os.path.join(app.config[f'stg_{stg}_path'], app.config['image_data_name'])
    with open(path, 'rb') as f:
        res = pickle.load(f)

    res = pd.DataFrame(res)
    res['datetime'] = pd.to_datetime(res['datetime'])
    res.set_index('datetime', inplace=True)
    res = res.resample('D').mean()
    res = res.dropna()
    res.reset_index(inplace=True)
    res['date'] = res['datetime'].dt.date

    # Live account history; keep only the 15:00 snapshot of each day.
    sql = f"""SELECT * FROM {stg.upper()}.account"""
    result = db.session.execute(text(sql))
    account = []
    for row in result:
        tp = row._asdict()
        account.append(tp)
    account = pd.DataFrame(account)
    account['hour'] = account['datetime'].dt.hour
    account = account.loc[account['hour'] == 15]  # drop the 23:00 snapshots
    account['date'] = account['datetime'].dt.date

    # Align backtest and live data by calendar day.
    res = pd.merge(res, account, how='outer', left_on='date', right_on='date')
    res['date'] = res['date'].apply(str)
    res.dropna(inplace=True)
    res = res.to_dict(orient='list')

    option['xAxis']["data"] = res['date']
    option['series'][0]["data"] = res['balance']
    option['series'][1]["data"] = res['equity']

    # Per-trade historical comparison.
    out_put, cmp, od, merged_df, trade_record, real_trade = compare(app, stg, days)
    ret = {
        "out_put": pickle_encode(out_put),
        "cmp": pickle_encode(cmp),
        "od": od,
        "merged_df": pickle_encode(merged_df),
        "trade_record": pickle_encode(trade_record),
        "real_trade": pickle_encode(real_trade),
        "echart_option": option
    }
    return jsonify(ret)


def pickle_encode(data):
    """Pickle *data* and return it as a base64 string (safe to embed in JSON)."""
    raw = pickle.dumps(data)
    return base64.b64encode(raw).decode('utf-8')


def compare(app_, stg, cmp_time=1):
    """
    Build the live-vs-backtest comparison tables for a strategy.

    - app_: the Flask app whose config holds the strategy paths.
    - stg: strategy name (lowercase, as used in the config keys).
    - cmp_time: 1 = last 24 hours, 20 = last 20 days, anything else = all history.

    Returns (out_put, cmp, od, merged_df, trade_record, real_trade):
    - out_put: backtest trades with no matching live trade.
    - cmp: live trades matched to backtest trades, with slippage columns.
    - od: column display order for the front end.
    - merged_df: raw signal/trade join.
    - trade_record: backtest trade log.
    - real_trade: live fills aggregated per order id (VWAP).
    """
    table_trade_name = app_.config[f'stg_{stg}_table_trade']
    table_signal_name = app_.config[f'stg_{stg}_table_signal']
    if cmp_time == 1:
        start_time = time.time() - 24 * 60 * 60
    elif cmp_time == 20:
        start_time = time.time() - 20 * 24 * 60 * 60
    else:
        start_time = 0
    engine = db.engine.connect()
    signal, trade = get_rb_trade_data(engine, table_trade_name, table_signal_name, stg.upper(),
                                      start_time=start_time)
    engine.close()

    # One signal may map to several trades: trade_id is a list stored as a string,
    # so parse it and explode to one trade per row. ast.literal_eval replaces the
    # original eval(): the string comes from the database and must never be
    # executed as arbitrary code.
    signal['trade_id'] = signal['trade_id'].apply(ast.literal_eval)
    signal = signal.explode('trade_id')
    del signal['price']
    merged_df = signal.merge(trade, how='left', on='trade_id')
    merged_df.dropna(inplace=True)
    merged_df['time'] = merged_df['time'].astype(float)
    # Epoch seconds shifted by +8h (local UTC+8 timestamps).
    merged_df['date'] = pd.to_datetime(merged_df['time'], unit='s', origin=datetime(1970, 1, 1, 8, 0, 0))
    merged_df['vwap'] = merged_df['price'] * merged_df['volume']

    # Aggregate fills into one row per order id; vwap becomes volume-weighted below.
    real_trade = merged_df.groupby(by=['id']).agg({
        'symbol': 'last',
        'strategy_name': 'last',
        'direction': 'last',
        'offset': 'last',
        'volume': 'sum',
        'date': 'mean',
        'vwap': 'sum'
    })
    real_trade['vwap'] /= real_trade['volume']

    # Backtest trade log (was reading the global `app`; use the app_ parameter).
    trade_record = pd.read_csv(os.path.join(app_.config[f'stg_{stg}_path'], app_.config['trade_file_name']))
    trade_record['date'] = pd.to_datetime(trade_record['date'])
    trade_record['date_backtest'] = trade_record['date']
    # Match each live trade to the nearest backtest trade within 10 minutes.
    cmp = pd.merge_asof(real_trade, trade_record, on='date', direction='nearest', tolerance=pd.Timedelta(minutes=10),
                        suffixes=('_real', '_backtest'))
    cmp['time_diff'] = cmp['date'] - cmp['date_backtest']
    cmp['time_diff'] = cmp['time_diff'].astype('str')

    # Slippage: realised VWAP vs backtest price, sign-flipped for sells.
    cmp['slippage'] = cmp['vwap'] / cmp['price'] - 1
    cmp.loc[cmp['direction'] == 'SELL', 'slippage'] *= -1
    cmp = cmp.sort_values(by='date', ascending=True)  # sort first so the cumprod compounds in time order
    cmp['slippage_cum'] = (cmp['slippage'] + 1).cumprod()
    cmp['slippage_cum'] = cmp['slippage_cum'].map('{:.6f}'.format)
    cmp['slippage'] = cmp['slippage'].fillna(0)
    cmp['slippage'] = cmp['slippage'].map('{:.6f}'.format)

    # Backtest trades after live trading started that have no live counterpart.
    time_min = cmp.date.min()
    tp = trade_record.loc[trade_record['date_backtest'] >= time_min]
    out_put = []
    for i, row in tp.iterrows():
        if row['date_backtest'] not in cmp['date_backtest'].values:
            out_put.append(row)
    if not out_put:
        out_put = pd.DataFrame()
    else:
        out_put = pd.concat(out_put, axis=1).T  # trades in the backtest but not in live

    # Column display order for the comparison table.
    od = (
        'date', 'date_backtest', 'time_diff', 'symbol', 'instrument', 'direction', 'offset', 'action', 'vwap', 'price',
        'volume_real', 'volume_backtest', 'slippage', 'slippage_cum'
    )
    cmp = cmp.sort_values(by='date', ascending=False)  # newest first for display

    merged_df = merged_df.sort_values(by='date', ascending=False)

    trade_record = trade_record.sort_values(by='date', ascending=False)
    real_trade = real_trade.sort_values(by='date', ascending=False)
    return out_put, cmp, od, merged_df, trade_record, real_trade


@app.route("/get_static/<string:filename>")
def get_static(filename):
    """Serve a file from ./static/<stg>/; .pkl files are unpickled and base64-encoded."""
    stg = request.args.get('stg')
    if not stg or not filename:
        app.logger.error(f"缺少策略名或文件！请检查{stg},{filename}")
        abort(400)  # was missing: the original logged the error and fell through
    # `stg` becomes part of a filesystem path — reject separators and traversal.
    if '/' in stg or '\\' in stg or '..' in stg:
        app.logger.error(f"非法的策略名：{stg}")
        abort(400)
    if 'pkl' in filename:
        # The pkl branch previously interpolated a placeholder instead of the
        # requested filename, making it unusable.
        with open(f"./static/{stg}/{filename}", 'rb') as f:
            res = pickle.load(f)
        return pickle_encode(res)
    else:
        return send_from_directory(f"./static/{stg}/", filename)


if __name__ == "__main__":
    # Development entry point: listen on all interfaces, port 8000.
    app.run(host='0.0.0.0', port=8000)